#!/usr/bin/env python3

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-ignore-all-errors

{%- set mdesc = "ssd" if ssd else "split" %}
{%- set sdesc = "_ssd" if ssd else "" %}

import torch
{%- if is_experimental_optimizer %}
import warnings
{%- endif %}
from .lookup_args{{ sdesc }} import *

{%- if is_fbcode %}

# This section is emitted only when the template is rendered with `is_fbcode`
# set. It registers the native operator libraries that back the
# `torch.ops.fbgemm.*` calls made by `invoke` below.
from fbgemm_gpu.utils.loader import load_torch_module, load_torch_module_bc
# Provide compatibility to downstream packages for eventual migration to the split training / inference packages
try:
    # First attempt: the training-specific targets. Each call is given several
    # alternative target names (a generic GPU target plus CUDA and HIP
    # variants; two spellings of the CPU target).
    # NOTE(review): presumably the loader helpers pick whichever of the listed
    # targets is available -- confirm against fbgemm_gpu.utils.loader.
    load_torch_module(
        "//deeplearning/fbgemm/fbgemm_gpu/codegen:embedding_ops_training_gpu",
        "//deeplearning/fbgemm/fbgemm_gpu/codegen:embedding_ops_cuda_training",
        "//deeplearning/fbgemm/fbgemm_gpu/codegen:embedding_ops_hip_training",
    )
    load_torch_module_bc(
        "//deeplearning/fbgemm/fbgemm_gpu/codegen:embedding_ops_training_cpu",
        "//deeplearning/fbgemm/fbgemm_gpu/codegen:embedding_ops_cpu_training",
    )
except Exception:
    # Deliberate best-effort fallback: if anything above fails, load the
    # combined `embedding_ops` / `embedding_ops_cpu` libraries instead.
    torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu/codegen:embedding_ops")
    torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu/codegen:embedding_ops_cpu")

# These libraries are loaded unconditionally, whichever path was taken above.
torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:embedding_inplace_update")
torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:split_table_batched_embeddings")
torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:cumem_utils")
torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops")
torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops_cpu")

{%- endif %}

{%- if is_prototype_optimizer %}
# Decorate the prototype optimizers which may be deprecated in the future with jit.ignore to avoid
# possible errors from torch.jit.script. 
# Note that backends can be removed but the lookup invoker is still needed for backward compatibility
@torch.jit.ignore
{%- endif %}
def invoke(
    common_args: CommonArgs,
    optimizer_args: OptimizerArgs,
    {%- if "momentum1_dev" in args.split_function_arg_names %}
    momentum1: Momentum,
    {%- endif %}
    {%- if "momentum2_dev" in args.split_function_arg_names %}
    momentum2: Momentum,
    {%- endif %}
    {%- if "prev_iter_dev" in args.split_function_arg_names %}
    prev_iter: Momentum,
    {%- endif %}
    {%- if "row_counter_dev" in args.split_function_arg_names and "row_counter" not in args_pt2.unified_pt2.split_saved_tensorlist_optional %}
    row_counter: Momentum,
    {%- endif %}
    {%- if "iter" in args.split_function_arg_names %}
    iter: int,
    {%- endif %}
    {%- if "max_counter" in args.split_function_arg_names %}
    max_counter: float,
    {%- endif %}
    {%- if "total_unique_indices" in args.split_function_arg_names %}
    total_unique_indices: int,
    {%- endif %}
    {%- if "iter" not in args.split_function_arg_names %}
    iter: int = 0,
    {%- endif %}
    apply_global_weight_decay: bool = False,
    {%- if "prev_iter_dev" not in args.split_function_arg_names %}
    # only pass prev_iter_dev since prev_iter is never created on UVM
    prev_iter_dev: Optional[torch.Tensor] = None,
    {%- endif %}
    gwd_lower_bound: float = 0.0,
    {%- if "row_counter" in args_pt2.unified_pt2.split_saved_tensorlist_optional %}
    row_counter: Optional[Momentum] = None,
    {%- endif %}
) -> torch.Tensor:
    {%- if is_experimental_optimizer %}
    # By design, the warning only shows up once
    warnings.warn(
        f"""\033[93m
        [FBGEMM_GPU] NOTE: The training optimizer '{{ optimizer }}' is marked as
        EXPERIMENTAL and thus not optimized, in order to reduce code compilation
        times and build sizes!
        \033[0m"""
    )
    {%- endif %}

    vbe_metadata = common_args.vbe_metadata
    {%- if "row_counter" in args_pt2.unified_pt2.split_saved_tensorlist_optional %}
    if not optimizer_args.use_rowwise_bias_correction or row_counter is None:
        row_counter_dev = None
        row_counter_uvm = None
        row_counter_offsets = None
        row_counter_placements = None
    elif optimizer_args.use_rowwise_bias_correction and row_counter is None:
        assert False, "use_rowwise_bias_correction is set but row_counter cannot be None"
    else:
        row_counter_dev = row_counter.dev
        row_counter_uvm = row_counter.uvm
        row_counter_offsets = row_counter.offsets
        row_counter_placements = row_counter.placements    
    {%- endif %}
    {%- if has_cpu_support and not ssd %}
    if (common_args.host_weights.numel() > 0):
        T = common_args.D_offsets.numel() - 1
        vbe: bool = vbe_metadata.B_offsets is not None
        if vbe:
            # create offsets with fixed batch size max_B
            # not efficient but for now we just need a functional implementation for CPU
            max_B = vbe_metadata.max_B
            offsets = torch.empty([T * max_B + 1], dtype=common_args.offsets.dtype, device=common_args.offsets.device)
            for t in range(T):
                B_offsets = vbe_metadata.B_offsets
                assert isinstance(B_offsets, torch.Tensor)
                begin = B_offsets[t]
                end = B_offsets[t + 1]
                offsets[t * max_B : t * max_B + end - begin] = common_args.offsets[begin : end]
                offsets[t * max_B + end - begin : (t + 1) * max_B] = common_args.offsets[end]
            offsets[-1] = common_args.offsets[-1]
        else:
            offsets = common_args.offsets
        output = torch.ops.fbgemm.split_embedding_codegen_lookup_{{ optimizer }}_function_cpu(
            # common_args
            host_weights=common_args.host_weights,
            weights_placements=common_args.weights_placements,
            weights_offsets=common_args.weights_offsets,
            D_offsets=common_args.D_offsets,
            total_D=common_args.total_D,
            max_D=common_args.max_D,
            hash_size_cumsum=common_args.hash_size_cumsum,
            total_hash_size_bits=common_args.total_hash_size_bits,
            indices=common_args.indices,
            offsets=offsets,
            pooling_mode=common_args.pooling_mode,
            indice_weights=common_args.indice_weights,
            feature_requires_grad=common_args.feature_requires_grad,
            # optimizer_args
            gradient_clipping = optimizer_args.gradient_clipping,
            max_gradient=optimizer_args.max_gradient,
            stochastic_rounding=optimizer_args.stochastic_rounding,
            {%- if "learning_rate" in args.split_function_args_v1 %}
            learning_rate=optimizer_args.learning_rate,
            {%- endif %}
            {%- if "eps" in args.split_function_arg_names %}
            eps=optimizer_args.eps,
            {%- endif %}
            {%- if "beta1" in args.split_function_arg_names %}
            beta1=optimizer_args.beta1,
            {%- endif %}
            {%- if "beta2" in args.split_function_arg_names %}
            beta2=optimizer_args.beta2,
            {%- endif %}
            {%- if "weight_decay" in args.split_function_arg_names %}
            weight_decay=optimizer_args.weight_decay,
            {%- endif %}
            {%- if "weight_decay_mode" in args.split_function_arg_names %}
            weight_decay_mode=optimizer_args.weight_decay_mode,
            {%- endif %}
            {%- if "eta" in args.split_function_arg_names %}
            eta=optimizer_args.eta,
            {%- endif %}
            {%- if "momentum" in args.split_function_arg_names %}
            momentum=optimizer_args.momentum,
            {%- endif %}
            {%- if "counter_halflife" in args.split_function_arg_names %}
            counter_halflife=optimizer_args.counter_halflife,
            {%- endif %}
            {%- if "adjustment_iter" in args.split_function_arg_names %}
            adjustment_iter=optimizer_args.adjustment_iter,
            {%- endif %}
            {%- if "adjustment_ub" in args.split_function_arg_names %}
            adjustment_ub=optimizer_args.adjustment_ub,
            {%- endif %}
            {%- if "learning_rate_mode" in args.split_function_arg_names %}
            learning_rate_mode=optimizer_args.learning_rate_mode,
            {%- endif %}
            {%- if "grad_sum_decay" in args.split_function_arg_names %}
            grad_sum_decay=optimizer_args.grad_sum_decay,
            {%- endif %}
            {%- if "tail_id_threshold" in args.split_function_arg_names %}
            tail_id_threshold=optimizer_args.tail_id_threshold,
            {%- endif %}
            {%- if "is_tail_id_thresh_ratio" in args.split_function_arg_names %}
            is_tail_id_thresh_ratio=optimizer_args.is_tail_id_thresh_ratio,
            {%- endif %}
            {%- if "weight_norm_coefficient" in args.split_function_arg_names %}
            weight_norm_coefficient=optimizer_args.weight_norm_coefficient,
            {%- endif %}
            {%- if "lower_bound" in args.split_function_arg_names %}
            lower_bound=optimizer_args.lower_bound,
            {%- endif %}
            {%- if "regularization_mode" in args.split_function_arg_names %}
            regularization_mode=optimizer_args.regularization_mode,
            {%- endif %}
            {%- if "max_norm" in args.split_function_arg_names %}
            max_norm=optimizer_args.max_norm,
            {%- endif %}
            # momentum1
            {%- if "momentum1_dev" in args.split_function_arg_names %}
            momentum1_host=momentum1.host,
            momentum1_offsets=momentum1.offsets,
            momentum1_placements=momentum1.placements,
            {%- endif %}
            # momentum2
            {%- if "momentum2_dev" in args.split_function_arg_names %}
            momentum2_host=momentum2.host,
            momentum2_offsets=momentum2.offsets,
            momentum2_placements=momentum2.placements,
            {%- endif %}
            # prev_iter
            {%- if "prev_iter_dev" in args.split_function_arg_names %}
            prev_iter_host=prev_iter.host,
            prev_iter_offsets=prev_iter.offsets,
            prev_iter_placements=prev_iter.placements,
            {%- endif %}
            # row_counter
            {%- if "row_counter_dev" in args.split_function_arg_names and "row_counter" not in args_pt2.unified_pt2.split_saved_tensorlist_optional %}
            row_counter_host=row_counter.host,
            row_counter_offsets=row_counter.offsets,
            row_counter_placements=row_counter.placements,
            {%- endif %}
            # iter
            {%- if "iter" in args.split_function_arg_names %}
            iter=iter,
            {%- endif %}
            # max counter
            {%- if "max_counter" in args.split_function_arg_names %}
            max_counter=max_counter,
            {%- endif %}
        )
        if vbe:
            output_new = torch.empty([vbe_metadata.output_size], dtype=output.dtype, device=output.device)
            B_offsets_rank_per_feature = vbe_metadata.B_offsets_rank_per_feature
            assert isinstance(B_offsets_rank_per_feature, torch.Tensor)
            output_offsets_feature_rank = vbe_metadata.output_offsets_feature_rank
            assert isinstance(output_offsets_feature_rank, torch.Tensor)
            R = B_offsets_rank_per_feature.size(1) - 1
            for r in range(R):
                D_offset = 0
                for t in range(T):
                    o_begin = output_offsets_feature_rank[r * T + t].item()
                    o_end = output_offsets_feature_rank[r * T + t + 1].item()
                    D = common_args.D_offsets[t + 1].item() - common_args.D_offsets[t].item()
                    b_begin = B_offsets_rank_per_feature[t][r].item()
                    b_end = B_offsets_rank_per_feature[t][r + 1].item()
                    assert o_end - o_begin == (b_end - b_begin) * D
                    output_new[o_begin : o_end] = output[b_begin : b_end, D_offset : D_offset + D].flatten()
                    D_offset += D
            return output_new
        else:
            return output
    {%- if not has_gpu_support %}
    else:
        assert False, "{{ optimizer }} has only CPU support. host_weights.numel() must be greater than 0."
    {%- endif %}
    {%- endif %}

    {%- if has_gpu_support %}

    {%- if ssd %}
    ssd_tensors = []
    {%- for tensor in ssd_tensors %}
    assert "{{ tensor }}" in common_args.ssd_tensors, (
        "{{ tensor }} must be in common_args.ssd_tensors. "
        "Please check the backend version"
    )
    ssd_tensors.append(common_args.ssd_tensors["{{ tensor }}"])
    {%- endfor %}
    {%- endif %}

    return torch.ops.fbgemm.{{ mdesc }}_embedding_codegen_lookup_{{ optimizer }}_function(
        # common_args
        {%- if not dense %}
        placeholder_autograd_tensor=common_args.placeholder_autograd_tensor,
        {%- endif %}
        dev_weights=common_args.dev_weights,
        uvm_weights=common_args.uvm_weights,
        lxu_cache_weights=common_args.lxu_cache_weights,
        weights_placements=common_args.weights_placements,
        weights_offsets=common_args.weights_offsets,
        D_offsets=common_args.D_offsets,
        total_D=common_args.total_D,
        max_D=common_args.max_D,
        hash_size_cumsum=common_args.hash_size_cumsum,
        total_hash_size_bits=common_args.total_hash_size_bits,
        indices=common_args.indices,
        offsets=common_args.offsets,
        pooling_mode=common_args.pooling_mode,
        indice_weights=common_args.indice_weights,
        feature_requires_grad=common_args.feature_requires_grad,
        lxu_cache_locations=common_args.lxu_cache_locations,
        uvm_cache_stats=common_args.uvm_cache_stats,
        {%- if ssd %}
        ssd_tensors=ssd_tensors,
        {%- endif %}
        # VBE metadata
        B_offsets=vbe_metadata.B_offsets,
        vbe_output_offsets_feature_rank=vbe_metadata.output_offsets_feature_rank,
        vbe_B_offsets_rank_per_feature=vbe_metadata.B_offsets_rank_per_feature,
        max_B=vbe_metadata.max_B,
        max_B_feature_rank=vbe_metadata.max_B_feature_rank,
        vbe_output_size=vbe_metadata.output_size,
        # optimizer_args
        {%- if optimizer == "none" %}
        total_hash_size = optimizer_args.total_hash_size,
        {%- else %}
        gradient_clipping = optimizer_args.gradient_clipping,
        max_gradient=optimizer_args.max_gradient,
        stochastic_rounding=optimizer_args.stochastic_rounding,
        {%- endif %} # if optimizer == none
        {%- if "learning_rate" in args.split_function_args_v1 %}
        # V1 interface still accepts learning_rate as float
        learning_rate=optimizer_args.learning_rate,
        {%- endif %}
        {%- if "eps" in args.split_function_arg_names %}
        eps=optimizer_args.eps,
        {%- endif %}
        {%- if "beta1" in args.split_function_arg_names %}
        beta1=optimizer_args.beta1,
        {%- endif %}
        {%- if "beta2" in args.split_function_arg_names %}
        beta2=optimizer_args.beta2,
        {%- endif %}
        {%- if "weight_decay" in args.split_function_arg_names %}
        weight_decay=optimizer_args.weight_decay,
        {%- endif %}
        {%- if "weight_decay_mode" in args.split_function_arg_names %}
        weight_decay_mode=optimizer_args.weight_decay_mode,
        {%- endif %}
        {%- if "eta" in args.split_function_arg_names %}
        eta=optimizer_args.eta,
        {%- endif %}
        {%- if "momentum" in args.split_function_arg_names %}
        momentum=optimizer_args.momentum,
        {%- endif %}
        {%- if "counter_halflife" in args.split_function_arg_names %}
        counter_halflife=optimizer_args.counter_halflife,
        {%- endif %}
        {%- if "adjustment_iter" in args.split_function_arg_names %}
        adjustment_iter=optimizer_args.adjustment_iter,
        {%- endif %}
        {%- if "adjustment_ub" in args.split_function_arg_names %}
        adjustment_ub=optimizer_args.adjustment_ub,
        {%- endif %}
        {%- if "learning_rate_mode" in args.split_function_arg_names %}
        learning_rate_mode=optimizer_args.learning_rate_mode,
        {%- endif %}
        {%- if "grad_sum_decay" in args.split_function_arg_names %}
        grad_sum_decay=optimizer_args.grad_sum_decay,
        {%- endif %}
        {%- if "tail_id_threshold" in args.split_function_arg_names %}
        tail_id_threshold=optimizer_args.tail_id_threshold,
        {%- endif %}
        {%- if "is_tail_id_thresh_ratio" in args.split_function_arg_names %}
        is_tail_id_thresh_ratio=optimizer_args.is_tail_id_thresh_ratio,
        {%- endif %}
        {%- if "weight_norm_coefficient" in args.split_function_arg_names %}
        weight_norm_coefficient=optimizer_args.weight_norm_coefficient,
        {%- endif %}
        {%- if "lower_bound" in args.split_function_arg_names %}
        lower_bound=optimizer_args.lower_bound,
        {%- endif %}
        {%- if "regularization_mode" in args.split_function_arg_names %}
        regularization_mode=optimizer_args.regularization_mode,
        {%- endif %}
        {%- if "max_norm" in args.split_function_arg_names %}
        max_norm=optimizer_args.max_norm,
        {%- endif %}
        # momentum1
        {%- if "momentum1_dev" in args.split_function_arg_names %}
        momentum1_dev=momentum1.dev,
        momentum1_uvm=momentum1.uvm,
        momentum1_offsets=momentum1.offsets,
        momentum1_placements=momentum1.placements,
        {%- endif %}
        # momentum2
        {%- if "momentum2_dev" in args.split_function_arg_names %}
        momentum2_dev=momentum2.dev,
        momentum2_uvm=momentum2.uvm,
        momentum2_offsets=momentum2.offsets,
        momentum2_placements=momentum2.placements,
        {%- endif %}
        # prev_iter
        {%- if "prev_iter_dev" in args.split_function_arg_names %}
        prev_iter_dev=prev_iter.dev,
        prev_iter_uvm=prev_iter.uvm,
        prev_iter_offsets=prev_iter.offsets,
        prev_iter_placements=prev_iter.placements,
        {%- else %}
        {# // explicitly pass only prev_iter_dev for global weight decay #}
        prev_iter_dev=prev_iter_dev,
        {%- endif %}
        # row_counter
        {%- if "row_counter_dev" in args.split_function_arg_names and "row_counter" not in args_pt2.unified_pt2.split_saved_tensorlist_optional %}
        row_counter_dev=row_counter.dev,
        row_counter_uvm=row_counter.uvm,
        row_counter_offsets=row_counter.offsets,
        row_counter_placements=row_counter.placements,
        {%- endif %}
        {%- if "row_counter" in args_pt2.unified_pt2.split_saved_tensorlist_optional %}
        row_counter_dev=row_counter_dev,
        row_counter_uvm=row_counter_uvm,
        row_counter_offsets=row_counter_offsets,
        row_counter_placements=row_counter_placements,
        {%- endif %}
        {%- if "use_rowwise_bias_correction" in args_pt2.unified_pt2.split_function_arg_names %}
        use_rowwise_bias_correction=optimizer_args.use_rowwise_bias_correction,
        {%- endif %}
        # iter
        iter=iter,
        # max counter
        {%- if "max_counter" in args.split_function_arg_names %}
        max_counter=max_counter,
        {%- endif %}
        # total_unique_indices
        {%- if "total_unique_indices" in args.split_function_arg_names %}
        total_unique_indices = total_unique_indices,
        {%- endif %}
        output_dtype=common_args.output_dtype,
        is_experimental=common_args.is_experimental,
        use_uniq_cache_locations_bwd=common_args.use_uniq_cache_locations_bwd,
        use_homogeneous_placements=common_args.use_homogeneous_placements,
        apply_global_weight_decay=apply_global_weight_decay,
        gwd_lower_bound=gwd_lower_bound,
        mixed_D=common_args.mixed_D,
    )
    {%- endif %}
